{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "8190b728-f3ce-41ba-b28e-05666307e575",
      "metadata": {
        "id": "ooskkmedwdL3"
      },
      "outputs": [],
      "source": [
        "%load_ext autoreload\n",
        "%autoreload 2"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "5e7f71d5-533e-4d06-b4cb-e705771aaf9c",
      "metadata": {
        "id": "VrmKqakVwdL3"
      },
      "outputs": [],
      "source": [
        "import os\n",
        "\n",
        "xla_flags = os.environ.get(\"XLA_FLAGS\", \"\")\n",
        "xla_flags += \" --xla_gpu_triton_gemm_any=True\"\n",
        "os.environ[\"XLA_FLAGS\"] = xla_flags\n",
        "os.environ[\"XLA_PYTHON_CLIENT_PREALLOCATE\"] = \"false\"\n",
        "os.environ[\"MUJOCO_GL\"] = \"egl\""
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "3891e589-7f3c-479f-a581-4b3755ff3a40",
      "metadata": {
        "id": "qtdocYbbwdL3"
      },
      "outputs": [],
      "source": [
        "import functools\n",
        "import json\n",
        "from datetime import datetime\n",
        "\n",
        "import re\n",
        "import pandas as pd\n",
        "import jax\n",
        "import jax.numpy as jp\n",
        "import matplotlib.pyplot as plt\n",
        "import mediapy as media\n",
        "import mujoco\n",
        "import numpy as np\n",
        "from brax.training.agents.ppo import networks as ppo_networks\n",
        "from brax.training.agents.ppo import train as ppo\n",
        "from etils import epath\n",
        "from flax.training import orbax_utils\n",
        "from IPython.display import clear_output, display\n",
        "from orbax import checkpoint as ocp\n",
        "\n",
        "from mujoco_playground import registry, wrapper\n",
        "from mujoco_playground.config import locomotion_params, manipulation_params"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "6e24bf9b-ca17-4746-aa61-eb99e00fb711",
      "metadata": {
        "id": "kYD426g1wdL3"
      },
      "outputs": [],
      "source": [
        "def clean_string_for_filename(s):\n",
        "  s = s.strip()\n",
        "  s = s.replace(\" \", \"_\")\n",
        "  s = re.sub(r'[^\\w_-]', '', s)\n",
        "  return s\n",
        "\n",
        "\n",
        "df = {\n",
        "    'env': [],\n",
        "    'seed': [],\n",
        "    'training/walltime': [],\n",
        "    'step': [],\n",
        "    'eval/episode_reward': []\n",
        "}"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "519e55a1-b65a-4185-84c4-0f8ae8db2427",
      "metadata": {
        "id": "nxf5wmL3wdL3"
      },
      "outputs": [],
      "source": [
        "def run_for_env(env_name):\n",
        "    n_seeds = 3\n",
        "    env_cfg = registry.get_default_config(env_name)\n",
        "    try:\n",
        "      randomizer = registry.get_domain_randomizer(env_name)\n",
        "    except:\n",
        "      randomizer = None\n",
        "    if env_name in registry.locomotion.ALL_ENVS:\n",
        "        ppo_params = locomotion_params.brax_ppo_config(env_name)\n",
        "    else:\n",
        "        ppo_params = manipulation_params.brax_ppo_config(env_name)\n",
        "\n",
        "    training_params = dict(ppo_params)\n",
        "    del training_params[\"network_factory\"]\n",
        "\n",
        "    for i in range(n_seeds):\n",
        "        print(f'Running seed: {i}')\n",
        "        training_params['seed'] = i\n",
        "\n",
        "        x_data, y_data, y_dataerr = [], [], []\n",
        "        times = [datetime.now()]\n",
        "\n",
        "        def progress(num_steps, metrics):\n",
        "            clear_output(wait=True)\n",
        "\n",
        "            times.append(datetime.now())\n",
        "            x_data.append(num_steps)\n",
        "            y_data.append(metrics[\"eval/episode_reward\"])\n",
        "            y_dataerr.append(metrics[\"eval/episode_reward_std\"])\n",
        "\n",
        "            plt.xlim([0, training_params[\"num_timesteps\"] * 1.25])\n",
        "            plt.xlabel(\"# environment steps\")\n",
        "            plt.ylabel(\"reward per episode\")\n",
        "            plt.title(f\"y={y_data[-1]:.3f}\")\n",
        "            plt.errorbar(x_data, y_data, yerr=y_dataerr, color=\"blue\")\n",
        "\n",
        "            df['env'].append(env_name)\n",
        "            df['seed'].append(training_params['seed'])\n",
        "            df['training/walltime'].append((times[-1] - times[0]).total_seconds())\n",
        "            df['step'].append(num_steps)\n",
        "            df['eval/episode_reward'].append(metrics[\"eval/episode_reward\"])\n",
        "\n",
        "            display(plt.gcf())\n",
        "\n",
        "        train_fn = functools.partial(\n",
        "          ppo.train,\n",
        "          **training_params,\n",
        "          network_factory=functools.partial(\n",
        "              ppo_networks.make_ppo_networks,\n",
        "              **ppo_params.network_factory\n",
        "          ),\n",
        "          progress_fn=progress,\n",
        "          wrap_env_fn=wrapper.wrap_for_brax_training,\n",
        "          randomization_fn=randomizer,\n",
        "        )\n",
        "\n",
        "        env = registry.load(env_name, config=env_cfg)\n",
        "        eval_env = registry.load(env_name, config=env_cfg)\n",
        "        make_inference_fn, params, _ = train_fn(environment=env, eval_env=eval_env)\n",
        "\n",
        "        print(f\"time to jit: {times[1] - times[0]}\")\n",
        "        print(f\"time to train: {times[-1] - times[1]}\")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "400c3f28-574c-42ab-9d6d-aea793343c7f",
      "metadata": {
        "scrolled": true,
        "id": "xjwdGIyiwdL3"
      },
      "outputs": [],
      "source": [
        "run_for_env(\"Go1JoystickFlatTerrain\")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "331f73b6-1876-4179-b693-89c9f4b4cfa2",
      "metadata": {
        "id": "5GfoYGAnwdL3"
      },
      "outputs": [],
      "source": [
        "run_for_env(\"LeapCubeReorient\")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "51a7cb56-412f-45b5-801a-710a3e5c9216",
      "metadata": {
        "id": "BOajpcg-wdL3"
      },
      "outputs": [],
      "source": [
        "run_for_env(\"G1Joystick\")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "b62cb938-3177-4950-9caa-af7f196d3515",
      "metadata": {
        "id": "Z90ePxNGwdL3"
      },
      "outputs": [],
      "source": [
        "num_devices = len(jax.devices())\n",
        "device_kind = jax.devices()[0].device_kind\n",
        "device_topo = f'{num_devices}x {device_kind}'"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "6b00d716-5c13-49dc-9a8f-866c545aac27",
      "metadata": {
        "id": "2SdyS-ZPwdL3"
      },
      "outputs": [],
      "source": [
        "del df['device_topo']\n",
        "df = pd.DataFrame(df)\n",
        "df['device_topo'] = device_topo"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "6fcdb030-105b-4d9e-94af-ef99a4fce86b",
      "metadata": {
        "id": "t6JJiPXtwdL3"
      },
      "outputs": [],
      "source": [
        "df.to_csv('../data/' + clean_string_for_filename(device_topo) + '.csv')"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "5e2a3216-f70b-433c-ae96-be224d2acd63",
      "metadata": {
        "id": "pAfy75mnwdL3"
      },
      "outputs": [],
      "source": [
        "# df = {k: list(v.values()) for k, v in df.to_dict().items()}"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "4386f582-701f-48c7-aa0c-654c052c145f",
      "metadata": {
        "id": "95lfVyXWwdL3"
      },
      "outputs": [],
      "source": [
        "df"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "dabc3667-e82c-4c7c-93f0-f4aed1408baf",
      "metadata": {
        "id": "PNIGwxsgwdL3"
      },
      "outputs": [],
      "source": []
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3 (ipykernel)",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.12.3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 5
}
